{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1> Training a flight delay model in BigQuery ML </h1>\n",
    "\n",
    "Run this notebook in Vertex Workbench. We will use BigQuery ML to train the same model that we trained in Spark ML.\n",
    "\n",
    "Note how much easier this is (and also much more scalable)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verify dataset\n",
    "\n",
    "Let's make sure that the `flights_tzcorr` and `trainday` tables are in BigQuery. If they are not, follow the steps in the README.md in this directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.05s: 100%|██████████| 2/2 [00:00<00:00, 383.69query/s]                         \n",
      "Downloading: 100%|██████████| 5/5 [00:01<00:00,  4.14rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>FL_DATE</th>\n",
       "      <th>UNIQUE_CARRIER</th>\n",
       "      <th>ORIGIN_AIRPORT_SEQ_ID</th>\n",
       "      <th>ORIGIN</th>\n",
       "      <th>DEST_AIRPORT_SEQ_ID</th>\n",
       "      <th>DEST</th>\n",
       "      <th>CRS_DEP_TIME</th>\n",
       "      <th>DEP_TIME</th>\n",
       "      <th>DEP_DELAY</th>\n",
       "      <th>TAXI_OUT</th>\n",
       "      <th>...</th>\n",
       "      <th>ARR_DELAY</th>\n",
       "      <th>CANCELLED</th>\n",
       "      <th>DIVERTED</th>\n",
       "      <th>DISTANCE</th>\n",
       "      <th>DEP_AIRPORT_LAT</th>\n",
       "      <th>DEP_AIRPORT_LON</th>\n",
       "      <th>DEP_AIRPORT_TZOFFSET</th>\n",
       "      <th>ARR_AIRPORT_LAT</th>\n",
       "      <th>ARR_AIRPORT_LON</th>\n",
       "      <th>ARR_AIRPORT_TZOFFSET</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2015-12-09</td>\n",
       "      <td>AS</td>\n",
       "      <td>1029904</td>\n",
       "      <td>ANC</td>\n",
       "      <td>1055102</td>\n",
       "      <td>BET</td>\n",
       "      <td>2015-12-10 04:00:00+00:00</td>\n",
       "      <td>2015-12-10 03:55:00+00:00</td>\n",
       "      <td>-5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>399.0</td>\n",
       "      <td>61.174167</td>\n",
       "      <td>-149.998056</td>\n",
       "      <td>-32400.0</td>\n",
       "      <td>60.778611</td>\n",
       "      <td>-161.837222</td>\n",
       "      <td>-32400.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2015-02-07</td>\n",
       "      <td>B6</td>\n",
       "      <td>1169703</td>\n",
       "      <td>FLL</td>\n",
       "      <td>1484304</td>\n",
       "      <td>SJU</td>\n",
       "      <td>2015-02-07 13:50:00+00:00</td>\n",
       "      <td>2015-02-07 14:23:00+00:00</td>\n",
       "      <td>33.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>...</td>\n",
       "      <td>22.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>1046.0</td>\n",
       "      <td>26.072500</td>\n",
       "      <td>-80.152778</td>\n",
       "      <td>-18000.0</td>\n",
       "      <td>18.439444</td>\n",
       "      <td>-66.002222</td>\n",
       "      <td>-14400.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2015-10-29</td>\n",
       "      <td>AS</td>\n",
       "      <td>1482803</td>\n",
       "      <td>SIT</td>\n",
       "      <td>1252303</td>\n",
       "      <td>JNU</td>\n",
       "      <td>2015-10-29 14:00:00+00:00</td>\n",
       "      <td>2015-10-29 13:57:00+00:00</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>...</td>\n",
       "      <td>10.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>95.0</td>\n",
       "      <td>57.046944</td>\n",
       "      <td>-135.361111</td>\n",
       "      <td>-28800.0</td>\n",
       "      <td>58.354722</td>\n",
       "      <td>-134.574722</td>\n",
       "      <td>-28800.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2015-10-02</td>\n",
       "      <td>HA</td>\n",
       "      <td>1298202</td>\n",
       "      <td>LIH</td>\n",
       "      <td>1383002</td>\n",
       "      <td>OGG</td>\n",
       "      <td>2015-10-02 22:53:00+00:00</td>\n",
       "      <td>2015-10-02 22:58:00+00:00</td>\n",
       "      <td>5.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>...</td>\n",
       "      <td>8.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>201.0</td>\n",
       "      <td>21.976111</td>\n",
       "      <td>-159.338889</td>\n",
       "      <td>-36000.0</td>\n",
       "      <td>20.898611</td>\n",
       "      <td>-156.430556</td>\n",
       "      <td>-36000.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2015-06-30</td>\n",
       "      <td>AS</td>\n",
       "      <td>1075403</td>\n",
       "      <td>BRW</td>\n",
       "      <td>1470903</td>\n",
       "      <td>SCC</td>\n",
       "      <td>2015-07-01 04:20:00+00:00</td>\n",
       "      <td>2015-07-01 04:16:00+00:00</td>\n",
       "      <td>-4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>...</td>\n",
       "      <td>-8.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>204.0</td>\n",
       "      <td>71.284722</td>\n",
       "      <td>-156.768611</td>\n",
       "      <td>-28800.0</td>\n",
       "      <td>70.195556</td>\n",
       "      <td>-148.465833</td>\n",
       "      <td>-28800.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 25 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      FL_DATE UNIQUE_CARRIER ORIGIN_AIRPORT_SEQ_ID ORIGIN DEST_AIRPORT_SEQ_ID  \\\n",
       "0  2015-12-09             AS               1029904    ANC             1055102   \n",
       "1  2015-02-07             B6               1169703    FLL             1484304   \n",
       "2  2015-10-29             AS               1482803    SIT             1252303   \n",
       "3  2015-10-02             HA               1298202    LIH             1383002   \n",
       "4  2015-06-30             AS               1075403    BRW             1470903   \n",
       "\n",
       "  DEST              CRS_DEP_TIME                  DEP_TIME  DEP_DELAY  \\\n",
       "0  BET 2015-12-10 04:00:00+00:00 2015-12-10 03:55:00+00:00       -5.0   \n",
       "1  SJU 2015-02-07 13:50:00+00:00 2015-02-07 14:23:00+00:00       33.0   \n",
       "2  JNU 2015-10-29 14:00:00+00:00 2015-10-29 13:57:00+00:00       -3.0   \n",
       "3  OGG 2015-10-02 22:53:00+00:00 2015-10-02 22:58:00+00:00        5.0   \n",
       "4  SCC 2015-07-01 04:20:00+00:00 2015-07-01 04:16:00+00:00       -4.0   \n",
       "\n",
       "   TAXI_OUT  ... ARR_DELAY CANCELLED  DIVERTED DISTANCE DEP_AIRPORT_LAT  \\\n",
       "0      10.0  ...      -3.0     False     False    399.0       61.174167   \n",
       "1      13.0  ...      22.0     False     False   1046.0       26.072500   \n",
       "2       8.0  ...      10.0     False     False     95.0       57.046944   \n",
       "3       9.0  ...       8.0     False     False    201.0       21.976111   \n",
       "4       5.0  ...      -8.0     False     False    204.0       71.284722   \n",
       "\n",
       "   DEP_AIRPORT_LON  DEP_AIRPORT_TZOFFSET  ARR_AIRPORT_LAT  ARR_AIRPORT_LON  \\\n",
       "0      -149.998056              -32400.0        60.778611      -161.837222   \n",
       "1       -80.152778              -18000.0        18.439444       -66.002222   \n",
       "2      -135.361111              -28800.0        58.354722      -134.574722   \n",
       "3      -159.338889              -36000.0        20.898611      -156.430556   \n",
       "4      -156.768611              -28800.0        70.195556      -148.465833   \n",
       "\n",
       "   ARR_AIRPORT_TZOFFSET  \n",
       "0              -32400.0  \n",
       "1              -14400.0  \n",
       "2              -28800.0  \n",
       "3              -36000.0  \n",
       "4              -28800.0  \n",
       "\n",
       "[5 rows x 25 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM dsongcp.flights_tzcorr\n",
    "LIMIT 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 383.64query/s]                          \n",
      "Downloading: 100%|██████████| 5/5 [00:01<00:00,  4.52rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>FL_DATE</th>\n",
       "      <th>is_train_day</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2015-01-01</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2015-01-02</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2015-01-03</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2015-01-04</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2015-01-05</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      FL_DATE is_train_day\n",
       "0  2015-01-01         True\n",
       "1  2015-01-02        False\n",
       "2  2015-01-03        False\n",
       "3  2015-01-04         True\n",
       "4  2015-01-05         True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM dsongcp.trainday\n",
    "LIMIT 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic regression\n",
    "\n",
    "Let's use SQL to craft the dataset just the way we want it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1109.31query/s]                        \n",
      "Downloading: 100%|██████████| 5/5 [00:01<00:00,  4.66rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ontime</th>\n",
       "      <th>dep_delay</th>\n",
       "      <th>taxi_out</th>\n",
       "      <th>distance</th>\n",
       "      <th>is_eval_day</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ontime</td>\n",
       "      <td>-5.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>399.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>late</td>\n",
       "      <td>33.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>1046.0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ontime</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>95.0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ontime</td>\n",
       "      <td>5.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>201.0</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ontime</td>\n",
       "      <td>-4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>204.0</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ontime  dep_delay  taxi_out  distance  is_eval_day\n",
       "0  ontime       -5.0      10.0     399.0         True\n",
       "1    late       33.0      13.0    1046.0        False\n",
       "2  ontime       -3.0       8.0      95.0        False\n",
       "3  ontime        5.0       9.0     201.0         True\n",
       "4  ontime       -4.0       5.0     204.0        False"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False\n",
    "LIMIT 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model:\n",
    "* the `ontime` column is the label\n",
    "* the model is a logistic regression\n",
    "* the `is_eval_day` column splits the data into training and evaluation sets\n",
    "* the remaining columns are features\n",
    "\n",
    "This will take about 10 minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1061.58query/s]                        \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: []\n",
       "Index: []"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_lm\n",
    "OPTIONS(input_label_cols=['ontime'], \n",
    "        model_type='logistic_reg', \n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 644.29query/s] \n",
      "Downloading: 100%|██████████| 20/20 [00:01<00:00, 18.82rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>training_run</th>\n",
       "      <th>iteration</th>\n",
       "      <th>loss</th>\n",
       "      <th>eval_loss</th>\n",
       "      <th>learning_rate</th>\n",
       "      <th>duration_ms</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>19</td>\n",
       "      <td>0.000003</td>\n",
       "      <td>0.000004</td>\n",
       "      <td>104857.6</td>\n",
       "      <td>3306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>18</td>\n",
       "      <td>0.000007</td>\n",
       "      <td>0.000007</td>\n",
       "      <td>52428.8</td>\n",
       "      <td>3350</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>17</td>\n",
       "      <td>0.000013</td>\n",
       "      <td>0.000012</td>\n",
       "      <td>26214.4</td>\n",
       "      <td>2941</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>16</td>\n",
       "      <td>0.000027</td>\n",
       "      <td>0.000020</td>\n",
       "      <td>13107.2</td>\n",
       "      <td>2921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "      <td>0.000054</td>\n",
       "      <td>0.000033</td>\n",
       "      <td>6553.6</td>\n",
       "      <td>3370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>14</td>\n",
       "      <td>0.000108</td>\n",
       "      <td>0.000058</td>\n",
       "      <td>3276.8</td>\n",
       "      <td>2693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>0.000216</td>\n",
       "      <td>0.000102</td>\n",
       "      <td>1638.4</td>\n",
       "      <td>3356</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>12</td>\n",
       "      <td>0.000432</td>\n",
       "      <td>0.000182</td>\n",
       "      <td>819.2</td>\n",
       "      <td>3068</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>0.000865</td>\n",
       "      <td>0.000333</td>\n",
       "      <td>409.6</td>\n",
       "      <td>3291</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0</td>\n",
       "      <td>10</td>\n",
       "      <td>0.001734</td>\n",
       "      <td>0.000625</td>\n",
       "      <td>204.8</td>\n",
       "      <td>3051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "      <td>0.003484</td>\n",
       "      <td>0.001207</td>\n",
       "      <td>102.4</td>\n",
       "      <td>3412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>0.007037</td>\n",
       "      <td>0.002421</td>\n",
       "      <td>51.2</td>\n",
       "      <td>3210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0</td>\n",
       "      <td>7</td>\n",
       "      <td>0.014348</td>\n",
       "      <td>0.005100</td>\n",
       "      <td>25.6</td>\n",
       "      <td>3731</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0.029768</td>\n",
       "      <td>0.011594</td>\n",
       "      <td>12.8</td>\n",
       "      <td>3482</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>0.063422</td>\n",
       "      <td>0.029168</td>\n",
       "      <td>6.4</td>\n",
       "      <td>3982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>0.136566</td>\n",
       "      <td>0.085203</td>\n",
       "      <td>3.2</td>\n",
       "      <td>4235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0.267786</td>\n",
       "      <td>0.225114</td>\n",
       "      <td>1.6</td>\n",
       "      <td>4171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0.426662</td>\n",
       "      <td>0.409772</td>\n",
       "      <td>0.8</td>\n",
       "      <td>3742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.558553</td>\n",
       "      <td>0.555572</td>\n",
       "      <td>0.4</td>\n",
       "      <td>4570</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.644285</td>\n",
       "      <td>0.644522</td>\n",
       "      <td>0.2</td>\n",
       "      <td>3581</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    training_run  iteration      loss  eval_loss  learning_rate  duration_ms\n",
       "0              0         19  0.000003   0.000004       104857.6         3306\n",
       "1              0         18  0.000007   0.000007        52428.8         3350\n",
       "2              0         17  0.000013   0.000012        26214.4         2941\n",
       "3              0         16  0.000027   0.000020        13107.2         2921\n",
       "4              0         15  0.000054   0.000033         6553.6         3370\n",
       "5              0         14  0.000108   0.000058         3276.8         2693\n",
       "6              0         13  0.000216   0.000102         1638.4         3356\n",
       "7              0         12  0.000432   0.000182          819.2         3068\n",
       "8              0         11  0.000865   0.000333          409.6         3291\n",
       "9              0         10  0.001734   0.000625          204.8         3051\n",
       "10             0          9  0.003484   0.001207          102.4         3412\n",
       "11             0          8  0.007037   0.002421           51.2         3210\n",
       "12             0          7  0.014348   0.005100           25.6         3731\n",
       "13             0          6  0.029768   0.011594           12.8         3482\n",
       "14             0          5  0.063422   0.029168            6.4         3982\n",
       "15             0          4  0.136566   0.085203            3.2         4235\n",
       "16             0          3  0.267786   0.225114            1.6         4171\n",
       "17             0          2  0.426662   0.409772            0.8         3742\n",
       "18             0          1  0.558553   0.555572            0.4         4570\n",
       "19             0          0  0.644285   0.644522            0.2         3581"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.TRAINING_INFO(MODEL dsongcp.arr_delay_lm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 512.13query/s]                          \n",
      "Downloading: 100%|██████████| 4/4 [00:01<00:00,  3.72rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>processed_input</th>\n",
       "      <th>weight</th>\n",
       "      <th>category_weights</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dep_delay</td>\n",
       "      <td>-0.132984</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>taxi_out</td>\n",
       "      <td>-0.121715</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>distance</td>\n",
       "      <td>0.000223</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>__INTERCEPT__</td>\n",
       "      <td>4.762572</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  processed_input    weight category_weights\n",
       "0       dep_delay -0.132984               []\n",
       "1        taxi_out -0.121715               []\n",
       "2        distance  0.000223               []\n",
       "3   __INTERCEPT__  4.762572               []"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.WEIGHTS(MODEL dsongcp.arr_delay_lm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 2/2 [00:00<00:00, 815.93query/s]                         \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.12s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>predicted_ontime</th>\n",
       "      <th>predicted_ontime_probs</th>\n",
       "      <th>dep_delay</th>\n",
       "      <th>taxi_out</th>\n",
       "      <th>distance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ontime</td>\n",
       "      <td>[{'label': 'ontime', 'prob': 0.850350772376706...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>1231</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  predicted_ontime                             predicted_ontime_probs  \\\n",
       "0           ontime  [{'label': 'ontime', 'prob': 0.850350772376706...   \n",
       "\n",
       "   dep_delay  taxi_out  distance  \n",
       "0       12.0      14.0      1231  "
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_lm,\n",
    "                        (\n",
    "SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 1231 AS distance\n",
    "                        ))"
   ]
  },
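  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the prediction above can be reproduced by hand. This is a minimal sketch, assuming that the weights returned by `ML.WEIGHTS` apply to the raw (unstandardized) feature values and that the sigmoid of the weighted sum is the probability of the `ontime` label. The weights are copied from the output above, rounded to six decimals, so the result should only approximately match the probability of about 0.85 reported by `ML.PREDICT`. The names `WEIGHTS`, `INTERCEPT`, and `prob_ontime` are illustrative and not part of BigQuery ML."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Weights copied from the ML.WEIGHTS output above (rounded to 6 decimals).\n",
    "WEIGHTS = {'dep_delay': -0.132984, 'taxi_out': -0.121715, 'distance': 0.000223}\n",
    "INTERCEPT = 4.762572\n",
    "\n",
    "def prob_ontime(features):\n",
    "    # Logistic regression: sigmoid of the intercept plus the weighted feature sum.\n",
    "    # Assumption: the sigmoid output is the probability of the 'ontime' label.\n",
    "    logit = INTERCEPT + sum(WEIGHTS[name] * value for name, value in features.items())\n",
    "    return 1.0 / (1.0 + math.exp(-logit))\n",
    "\n",
    "# Same inputs as the ML.PREDICT query above.\n",
    "p = prob_ontime({'dep_delay': 12.0, 'taxi_out': 14.0, 'distance': 1231})\n",
    "print(round(p, 3))"
   ]
  },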
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 7/7 [00:00<00:00, 2886.65query/s]                        \n",
      "Downloading: 100%|██████████| 1/1 [00:02<00:00,  2.39s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>f1_score</th>\n",
       "      <th>log_loss</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.964337</td>\n",
       "      <td>0.956535</td>\n",
       "      <td>0.935174</td>\n",
       "      <td>0.96042</td>\n",
       "      <td>0.167233</td>\n",
       "      <td>0.956248</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   precision    recall  accuracy  f1_score  log_loss   roc_auc\n",
       "0   0.964337  0.956535  0.935174   0.96042  0.167233  0.956248"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * \n",
    "FROM ML.EVALUATE(MODEL dsongcp.arr_delay_lm,\n",
    "                 (\n",
    "                     \n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                     \n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
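    "## Compute the same metrics as in the Spark code\n",
    "\n",
    "We run ML.PREDICT on the evaluation days and compute the same statistics as in the Spark code: the fraction of late flights that were predicted late (correct_cancel), the fraction of predicted on-time flights that were actually on time (correct_noncancel), and the RMSE of the predicted on-time probability."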
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 4/4 [00:00<00:00, 1431.14query/s]                        \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.14s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.836363</td>\n",
       "      <td>1301948</td>\n",
       "      <td>0.964337</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.213091</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0        0.836363          1301948           0.964337        283750  0.213091"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT \n",
    "  *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_lm,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT \n",
    "  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "  , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "   correct_cancel / total_cancel AS correct_cancel\n",
    "   , total_noncancel\n",
    "   , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "   , total_cancel\n",
    "   , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add airport info\n",
    "\n",
    "Add airport information to the model (note the two additional columns: origin and dest). This seemingly simple change adds two categorical variables that, when one-hot encoded, add 600+ new columns to the model. BigQuery ML doesn't skip a beat ...\n",
    "\n",
    "This query will take about 10 minutes."
   ]
  },
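  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how many one-hot columns the two airport variables contribute, one could count the distinct airport codes before training. The query below is a minimal sketch, assuming that `origin` and `dest` in `dsongcp.flights_tzcorr` hold airport codes; the aliases `num_origins` and `num_dests` are illustrative names, not part of the original notebook. Run it in a `%%bigquery` cell.\n",
    "\n",
    "```sql\n",
    "-- Count the distinct airport codes among the flights used for modeling.\n",
    "-- Each distinct code becomes one column when the variable is one-hot encoded.\n",
    "SELECT\n",
    "  COUNT(DISTINCT origin) AS num_origins,\n",
    "  COUNT(DISTINCT dest) AS num_dests\n",
    "FROM dsongcp.flights_tzcorr\n",
    "WHERE\n",
    "  CANCELLED = False AND\n",
    "  DIVERTED = False\n",
    "```\n",
    "\n",
    "The sum of the two counts approximates the number of columns that the airport variables add to the model."
   ]
  },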
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1256.16query/s]                        \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: []\n",
       "Index: []"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_lm\n",
    "OPTIONS(input_label_cols=['ontime'], \n",
    "        model_type='logistic_reg', \n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 7/7 [00:00<00:00, 2760.97query/s]                        \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.11s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>f1_score</th>\n",
       "      <th>log_loss</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.967151</td>\n",
       "      <td>0.957706</td>\n",
       "      <td>0.938477</td>\n",
       "      <td>0.962405</td>\n",
       "      <td>0.165557</td>\n",
       "      <td>0.960821</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   precision    recall  accuracy  f1_score  log_loss   roc_auc\n",
       "0   0.967151  0.957706  0.938477  0.962405  0.165557  0.960821"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * \n",
    "FROM ML.EVALUATE(MODEL dsongcp.arr_delay_airports_lm,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 4/4 [00:00<00:00, 934.66query/s]                         \n",
      "Downloading: 100%|██████████| 1/1 [00:00<00:00,  1.01rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.84953</td>\n",
       "      <td>1299749</td>\n",
       "      <td>0.967151</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.209839</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0         0.84953          1299749           0.967151        283750  0.209839"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT \n",
    "  *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_lm,\n",
    "                 (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND \n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "                 ),\n",
    "                 STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT \n",
    "  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "  , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "   correct_cancel / total_cancel AS correct_cancel\n",
    "   , total_noncancel\n",
    "   , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "   , total_cancel\n",
    "   , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that adding the airport information has improved both the AUC (from 0.956 to 0.961) and the RMSE (from 0.213 to 0.210)."
   ]
  },
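  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see which airports drive the improvement, one option is to inspect the learned weights with `ML.WEIGHTS`. The query below is a sketch rather than part of the original analysis: it assumes the documented `ML.WEIGHTS` output for linear models, in which each categorical input carries a `category_weights` array of (`category`, `weight`) pairs, and the alias `cw` is an illustrative name. Run it in a `%%bigquery` cell.\n",
    "\n",
    "```sql\n",
    "-- List the ten airport categories with the largest absolute weights.\n",
    "-- Numeric inputs (dep_delay, taxi_out, distance) have an empty\n",
    "-- category_weights array, so the cross join with UNNEST drops them.\n",
    "SELECT\n",
    "  processed_input,\n",
    "  cw.category,\n",
    "  cw.weight\n",
    "FROM ML.WEIGHTS(MODEL dsongcp.arr_delay_airports_lm),\n",
    "  UNNEST(category_weights) AS cw\n",
    "WHERE processed_input IN ('origin', 'dest')\n",
    "ORDER BY ABS(cw.weight) DESC\n",
    "LIMIT 10\n",
    "```\n",
    "\n",
    "The sign of each weight is relative to the label encoding that BigQuery ML chose for `ontime`, so check it against a few known predictions before interpreting it."
   ]
  },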
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2019-2021 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "managed-notebooks.m82",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:latest"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
